Full basic implementation according to the white-paper.
"U-Net: Convolutional Networks for Biomedical Image Segmentation"
https://arxiv.org/pdf/1505.04597.pdf
import datetime
import pathlib
import sys
import random
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
from PIL import Image, ImageDraw
from sklearn.model_selection import train_test_split
%load_ext autoreload
%autoreload 2
# Disable PIL's decompression-bomb guard: the scanned maps are legitimately huge.
Image.MAX_IMAGE_PIXELS = None
# Fix both RNG sources so the random tiling below is reproducible across runs.
SEED=241
random.seed(SEED)
np.random.seed(SEED)
# Default figure size for the map/mask previews rendered by plot_masks().
plt.rcParams["figure.figsize"] = [26,19]
def load_map_and_mask(map_file, mask_file, workspace_dir):
    """Load a map image and rasterize its polygon markup into a binary mask.

    Args:
        map_file: path to the source map image (.tif).
        mask_file: path to the markup file; each line is one polygon given as
            space-separated "x,y" vertex pairs.
        workspace_dir: pathlib.Path directory where a JPEG preview of the
            mask is written (named after the markup file's stem).

    Returns:
        (map_img, mask_img): the opened PIL map image and a mode-'1' binary
        mask of the same size with polygon interiors filled with 1.
    """
    map_img = Image.open(str(map_file))
    polygons = []
    with open(str(mask_file)) as f:
        for line in f:
            points = line.strip().split(' ')
            polygon = [(int(xy[0]), int(xy[1]))
                       for xy in (point.split(',') for point in points)]
            polygons.append(polygon)
    mask_img = Image.new('1', map_img.size, 0)
    # Create the drawing context once instead of once per polygon.
    draw = ImageDraw.Draw(mask_img)
    for polygon in polygons:
        draw.polygon(polygon, fill=1)
    # NOTE(review): JPEG is lossy even at quality=100, so the saved preview is
    # not pixel-exact; the in-memory mask returned below is exact. Use PNG if
    # the file on disk must stay strictly binary.
    mask_img.convert('RGB').save(str(workspace_dir/(mask_file.stem + '.jpg')),
                                 format='JPEG', quality=100)
    return map_img, mask_img
def plot_masks(map_imgs, mask_imgs):
    """Render each map beside its mask in a two-column matplotlib grid."""
    rows = len(map_imgs)
    for idx in range(rows):
        # Left column: the map tile.
        plt.subplot(rows, 2, 2 * idx + 1)
        plt.imshow(map_imgs[idx])
        # Right column: the corresponding mask.
        plt.subplot(rows, 2, 2 * idx + 2)
        plt.imshow(mask_imgs[idx])
    plt.show()
def map_stats(map_img, mask_img):
    """Print the map size and the zero/one pixel ratios of its binary mask."""
    mask = np.array(mask_img).astype(np.byte)
    zeros = np.sum(mask == 0)
    ones = np.sum(mask == 1)
    total = zeros + ones
    print(map_img.size)
    # Guards keep an all-empty mask from dividing by zero.
    print('zeros ratio', 0 if zeros == 0 else round(zeros / total, 3))
    print('ones ratio', 0 if ones == 0 else round(ones / total, 3))
# Collect the source .tif maps, skipping any pre-rendered '.mask.' files.
dataset_dir = pathlib.Path().cwd() / 'external_data' / 'data' / 'train'
src_images = [p for p in dataset_dir.glob('**/*.tif') if '.mask.' not in p.name]
print('src_images', len(src_images))
# Fresh timestamped workspace directory for this run's artifacts.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
workspace_dir = pathlib.Path().cwd() / ('workspace_' + str(timestamp))
workspace_dir.mkdir(parents=True)
# Load each map with its rasterized markup mask and print class-balance stats.
src_images_and_masks = []
for img_p in src_images:
    markup_path = img_p.parent / (img_p.stem + '.markup.txt')
    map_img, mask = load_map_and_mask(img_p, markup_path, workspace_dir)
    src_images_and_masks.append({'map': map_img, 'mask': mask})
    print('Stats for', img_p)
    map_stats(map_img, mask)
    print()
# Visual sanity check on the first couple of map/mask pairs.
plot_masks(
    [src['map'] for src in src_images_and_masks][:2],
    [src['mask'] for src in src_images_and_masks][:2]
)
def cut_map_into_tiles(map_img,
                       mask_img,
                       tile_size=100,
                       tile_resize=100,
                       tiles_count=100,
                       tile_prefix='',
                       save_tiles=True,
                       save_dir=None):
    """Randomly sample square tiles from a map image and its aligned mask.

    Args:
        map_img: source PIL image (converted to RGB before cropping).
        mask_img: PIL mask image, same size as ``map_img``.
        tile_size: side length (px) of the square crop taken from the source.
        tile_resize: side length (px) each tile is resized to (network input).
        tiles_count: number of random tiles to sample.
        tile_prefix: filename prefix used when saving tile arrays.
        save_tiles: when True, every tile/mask pair is saved into ``save_dir``.
        save_dir: pathlib.Path target for saved tiles; required if
            ``save_tiles`` is True.

    Returns:
        (X, Y): lists of ``tiles_count`` RGB tile arrays and byte mask arrays.

    Raises:
        ValueError: if saving is requested without a ``save_dir``, or the map
            is smaller than ``tile_size`` in either dimension.
    """
    # Fail fast: previously a None save_dir crashed with a cryptic TypeError
    # on the first save attempt, after the tiles were already cut.
    if save_tiles and save_dir is None:
        raise ValueError('save_dir is required when save_tiles=True')
    width, height = map_img.size
    if width < tile_size or height < tile_size:
        raise ValueError('map %dx%d is smaller than tile_size=%d'
                         % (width, height, tile_size))
    X = []
    Y = []
    # max(..., 1) lets a map exactly tile_size wide/tall still yield its one
    # valid top-left corner (0) instead of randint raising on an empty range.
    top_left_coordinates = zip(
        np.random.randint(0, max(width - tile_size, 1), tiles_count),
        np.random.randint(0, max(height - tile_size, 1), tiles_count)
    )
    map_img_in_rgb = map_img.convert('RGB')
    for i, (x, y) in enumerate(top_left_coordinates):
        tile = map_img_in_rgb.crop((x, y, x + tile_size, y + tile_size))
        tile_mask = mask_img.crop((x, y, x + tile_size, y + tile_size))
        tile = tile.resize((tile_resize, tile_resize))
        tile_mask = tile_mask.resize((tile_resize, tile_resize))
        mp = np.array(tile)
        mask = np.array(tile_mask).astype(np.byte)
        X.append(mp)
        Y.append(mask)
        if save_tiles:
            # NOTE: np.save appends '.npy', so files land as '...np.npy';
            # names kept for backward compatibility with existing workspaces.
            np.save(str(save_dir/(tile_prefix + 'map_' + str(i) + '.np')), mp)
            np.save(str(save_dir/(tile_prefix + 'mask_' + str(i) + '.np')), mask)
    return X, Y
def tiles_stats(Y):
    """Print zero/one pixel counts and class ratios over a list of masks."""
    zeros_total = sum(np.sum(mask == 0) for mask in Y)
    ones_total = sum(np.sum(mask == 1) for mask in Y)
    total = ones_total + zeros_total
    print('zeros', zeros_total)
    print('ones', ones_total)
    # Only print ratios when the numerator is non-zero (also avoids 0/0).
    if zeros_total > 0:
        print('zeros ratio', zeros_total / total)
    if ones_total > 0:
        print('ones ratio', ones_total / total)
    print()
# Tiling hyper-parameters: cut TILES_SIZE-px squares from each source map,
# downscale them to the network's UNET_INPUT_SIZE, TILES_COUNT per map.
TILES_SIZE = 1024
TILES_COUNT = 1000
UNET_INPUT_SIZE = 256
X = []
Y = []
tiles_folder = workspace_dir / 'train_tiles'
tiles_folder.mkdir(parents=True, exist_ok=True)
for idx, src in enumerate(src_images_and_masks):
    tiles_x, tiles_y = cut_map_into_tiles(
        src['map'],
        src['mask'],
        tile_size=TILES_SIZE,
        tile_resize=UNET_INPUT_SIZE,
        tiles_count=TILES_COUNT,
        tile_prefix=str(idx),
        save_dir=tiles_folder
    )
    X.extend(tiles_x)
    Y.extend(tiles_y)
    print('done', idx)
print('X', len(X))
print('Y', len(Y))
tiles_stats(Y)
X = np.array(X)
# Add a trailing channel axis so Y matches the network's (H, W, 1) output.
Y = np.array(Y)[..., np.newaxis]
print(X.shape, Y.shape)
np.save(str(workspace_dir / 'X.np'), X)
np.save(str(workspace_dir / 'Y.np'), Y)
def binary_mask_to_img(data):
    """Convert a 2D 0/1 numpy array into a PIL mode-'1' bitmap image."""
    height, width = data.shape
    # packbits pads each row out to a whole byte, which is exactly the raw
    # row layout PIL expects for mode '1'.
    packed_rows = np.packbits(data, axis=1)
    return Image.frombytes(mode='1', size=(width, height), data=packed_rows)
# Visual sanity check: render a few training tiles alongside their masks.
show_count = 4
train_maps = [Image.fromarray(tile.astype('uint8'), 'RGB') for tile in X[:show_count]]
train_masks = [binary_mask_to_img(mask[:, :, -1]) for mask in Y[:show_count]]
plt.rcParams["figure.figsize"] = [50, 50]
plot_masks(train_maps, train_masks)
import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.keras import backend as K
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import (ModelCheckpoint, LearningRateScheduler, ModelCheckpoint, EarlyStopping,
ReduceLROnPlateau, TensorBoard, TerminateOnNaN, Callback)
from tensorflow.keras.models import load_model
from tensorflow.keras.models import model_from_json
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Report the TF version (notebook cell output) as a runtime sanity check.
tf.__version__
def get_available_gpus():
    """Return the names of GPU devices TensorFlow can see."""
    devices = device_lib.list_local_devices()
    return [dev.name for dev in devices if dev.device_type == 'GPU']
get_available_gpus()
def create_unet(input_sz=512):
    """Build the U-Net of Ronneberger et al. (2015), with 'same' padding.

    Contracting path: four (two 3x3 convs -> 2x2 maxpool) stages with
    64/128/256/512 filters, a 1024-filter bottleneck, then an expansive path
    of four (upsample -> 2x2 conv -> skip concat -> two 3x3 convs) stages
    mirroring the encoder. A final 1x1 sigmoid conv emits a one-channel
    soft mask the same spatial size as the input.

    Args:
        input_sz: side of the square RGB input; must be divisible by 16 so
            the four poolings divide evenly.

    Returns:
        An uncompiled tf.keras Model mapping (input_sz, input_sz, 3) to
        (input_sz, input_sz, 1) values in [0, 1].
    """
    def _double_conv(filters, x):
        # Two stacked 3x3 ReLU convolutions, as in the paper.
        x = Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(x)
        return Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(x)

    def _up_conv(filters, x):
        # 2x upsampling followed by a 2x2 "up-convolution".
        return Conv2D(filters, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(
            UpSampling2D(size = (2,2))(x)
        )

    image_input = Input((input_sz, input_sz, 3))
    # contracting path (down-sampling); keep each stage output for the skips
    skip1 = _double_conv(64, image_input)
    pool1 = MaxPooling2D(pool_size=(2, 2))(skip1)
    skip2 = _double_conv(128, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(skip2)
    skip3 = _double_conv(256, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(skip3)
    skip4 = _double_conv(512, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(skip4)
    bottleneck = _double_conv(1024, pool4)
    # expansive path (up-sampling); concatenate skips on the channel axis
    up4 = _double_conv(512, concatenate([skip4, _up_conv(512, bottleneck)], axis = 3))
    up3 = _double_conv(256, concatenate([skip3, _up_conv(256, up4)], axis = 3))
    up2 = _double_conv(128, concatenate([skip2, _up_conv(128, up3)], axis = 3))
    up1 = _double_conv(64, concatenate([skip1, _up_conv(64, up2)], axis = 3))
    # Light regularization right before the per-pixel classifier.
    up1 = SpatialDropout2D(0.2)(up1)
    output = Conv2D(1, 1, 1, activation = 'sigmoid')(up1)
    return Model(inputs = [image_input], outputs = output)
# Instantiate the U-Net with input resolution matching the resized tiles.
unet = create_unet(input_sz=UNET_INPUT_SIZE)
unet.summary()
def preprocess_inputs(X):
    """Rescale uint8 pixel values from [0, 255] to the [-1.0, 1.0] range."""
    scale = 2.0 / 255.0
    return scale * X - 1.0
# Network inputs are scaled once up-front; Y stays as {0, 1} byte masks.
X_preprocessed = preprocess_inputs(X)
# BUG FIX: import the loss from tensorflow.keras, not the standalone `keras`
# package -- mixing the two backends with the tf.keras models/K used in the
# rest of this file causes graph/session incompatibilities.
from tensorflow.keras.losses import binary_crossentropy
# Smoothing constant guarding the IoU metric below against division by zero.
SMOOTH = 1
def jaccard_score(gt, pr, smooth=SMOOTH, threshold=None):
    """
    Jaccard index https://en.wikipedia.org/wiki/Jaccard_index
    Args:
        gt: ground truth 4D keras tensor (B, H, W, C)
        pr: prediction 4D keras tensor (B, H, W, C)
        smooth: value to avoid division by zero
        threshold: value to round predictions (use ``>`` comparison);
            if ``None``, raw predictions are used (soft IoU)
    Returns:
        IoU/Jaccard score in range [0, 1], averaged over the batch
    """
    # Reduce over the spatial dimensions only (keep batch and channel).
    axes = [1, 2]
    if threshold is not None:
        # Binarize predictions, then cast the boolean tensor back to floats.
        pr = K.greater(pr, threshold)
        pr = K.cast(pr, K.floatx())
    intersection = K.sum(gt * pr, axis=axes)
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = K.sum(gt + pr, axis=axes) - intersection
    # Smoothing keeps the score defined (=1) when both masks are empty.
    iou = (intersection + smooth) / (union + smooth)
    # Average the per-image scores across the batch.
    iou = K.mean(iou, axis=0)
    return iou
def create_default_callbacks(workspace_dir, batch_sz=1):
    """Build the default training callbacks and a fresh checkpoint folder.

    Returns:
        (callbacks, checkpoint_folder): callbacks is [ReduceLROnPlateau,
        TerminateOnNaN, ModelCheckpoint, TensorBoard]; checkpoint_folder is
        the newly created timestamped directory that receives weight files.
    """
    run_stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    ckpt_dir = workspace_dir / 'checkpoints' / str(run_stamp)
    ckpt_dir.mkdir(parents=True)
    tb_dir = workspace_dir / 'tensorboard_logs' / str(run_stamp)
    # Keep only improving weight snapshots, keyed on training loss.
    best_weights = ModelCheckpoint(
        str(ckpt_dir / 'model-{loss:.2f}.h5'),
        monitor='loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode='auto',
        period=1
    )
    # NOTE(review): `halt` is created but never added to the returned list,
    # so early stopping is effectively disabled -- confirm if intentional.
    halt = EarlyStopping(monitor='loss', patience=200, mode='min', verbose=1)
    # Halve the LR after 5 stagnant epochs, down to a tiny floor.
    lr_schedule = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5,
                                    min_lr=1e-9, verbose=1, mode='min')
    board = TensorBoard(log_dir=str(tb_dir),
                        histogram_freq=0,
                        batch_size=batch_sz,
                        write_graph=False,
                        write_grads=False,
                        write_images=False,
                        embeddings_freq=0,
                        embeddings_layer_names=None,
                        embeddings_metadata=None,
                        embeddings_data=None)
    return [lr_schedule, TerminateOnNaN(), best_weights, board], ckpt_dir
def train(model, X_train, Y_train, workspace_dir, epochs=1, batch_sz=1):
    """Compile and fit ``model`` on the tile arrays.

    Args:
        model: an uncompiled tf.keras Model (e.g. from create_unet).
        X_train: preprocessed input tiles, shape (N, H, W, 3).
        Y_train: binary mask targets, shape (N, H, W, 1).
        workspace_dir: pathlib.Path used for checkpoints/tensorboard logs.
        epochs: number of training epochs.
        batch_sz: minibatch size.

    Returns:
        The Keras History object from ``model.fit``.
    """
    # BUG FIX: compile the model passed in, not the global `unet` -- before,
    # a different model instance could be compiled than the one fitted.
    model.compile(
        optimizer='Adam',
        loss=binary_crossentropy,
        metrics=[jaccard_score, 'binary_accuracy']
    )
    callbacks, checkpoint_dir = create_default_callbacks(workspace_dir, batch_sz=batch_sz)
    # Persist the architecture next to the weight checkpoints so the model
    # can be reconstructed later with model_from_json.
    with open(str(checkpoint_dir/'graph.json'), 'w') as json_file:
        json_file.write(model.to_json())
    return model.fit(
        X_train, Y_train,
        batch_size=batch_sz,
        epochs=epochs,
        callbacks=callbacks,
        shuffle=True
    )
# Long training run; ModelCheckpoint snapshots the best-loss weights.
train(unet, X_preprocessed, Y, workspace_dir, epochs=700, batch_sz=16)
import json
# NOTE(review): hard-coded checkpoint path from a previous run -- update the
# timestamp/loss segment to the checkpoint you actually want to evaluate.
final_model_path = workspace_dir / 'checkpoints' / '2019-06-26-08-44-13' / 'model-0.23.h5'
# Rebuild the architecture from the saved JSON graph, then load the weights.
with open(str(final_model_path.parent / 'graph.json'), 'r') as json_file:
    fitted_model = model_from_json(json_file.read())
fitted_model.load_weights(str(final_model_path), by_name=True)
fitted_model.summary()
# Sample a small held-out tile set for visual evaluation (not saved to disk).
X_test = []
Y_test = []
for idx, src in enumerate(src_images_and_masks):
    tiles_x, tiles_y = cut_map_into_tiles(
        src['map'],
        src['mask'],
        tile_size=TILES_SIZE,
        tile_resize=UNET_INPUT_SIZE,
        tiles_count=10,
        tile_prefix=str(idx),
        save_tiles=False
    )
    X_test.extend(tiles_x)
    Y_test.extend(tiles_y)
    print('done', idx)
def convert_grayscale_data_to_red_rgba(mask, alpha_value=50):
    """Turn a [0, 1] float mask into a red RGBA overlay image.

    The red channel is the mask scaled to [0, 255]; green and blue are zero;
    alpha is the mask scaled by ``alpha_value`` (translucent where mask > 0).
    """
    red = (mask * 255).astype('uint8')
    alpha = (mask * alpha_value).astype('uint8')
    empty = np.zeros_like(red)
    # Channel order [R, 0, 0, A] stacked along the last axis -> (H, W, 4).
    rgba_array = np.stack([red, empty, empty, alpha], axis=2)
    return Image.fromarray(rgba_array, 'RGBA')
def apply_predicted_mask(orig_image, predicted_2d_values):
    """Overlay the predicted mask in translucent red on top of a map tile."""
    overlay = convert_grayscale_data_to_red_rgba(predicted_2d_values)
    base = Image.fromarray(orig_image.astype('uint8'), 'RGB').convert('RGBA')
    # Use the overlay's own alpha channel as the paste mask.
    base.paste(overlay, (0, 0), overlay)
    return base.convert('RGB')
# Render model predictions overlaid in red on each held-out tile.
for tile, tile_mask in zip(X_test, Y_test):
    batch = preprocess_inputs(np.array(tile))[np.newaxis, :, :, :]
    predicted = fitted_model.predict(batch)
    # Drop batch and channel axes -> 2D heat map of per-pixel probabilities.
    heatmap = predicted.reshape((predicted.shape[1], predicted.shape[2]))
    display(apply_predicted_mask(tile, heatmap))